# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training DistilBERT.
"""
import os
import argparse
import pickle
import json
import shutil
import numpy as np
import torch

from transformers import BertTokenizer, BertForMaskedLM, RobertaTokenizer, RobertaForMaskedLM
from transformers import DistilBertForMaskedLM, DistilBertConfig

from distiller import Distiller
from utils import git_log, logger, init_gpu_params, set_seed
from dataset import Dataset


def main():
    parser = argparse.ArgumentParser(description="Training")

    parser.add_argument("--dump_path", type=str, required=True,
                        help="The output directory (log, checkpoints, parameters, etc.)")
    parser.add_argument("--data_file", type=str, required=True,
                        help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.")
    parser.add_argument("--token_counts", type=str, required=True,
                        help="The token counts in the data_file for MLM.")
    parser.add_argument("--force", action='store_true',
                        help="Overwrite dump_path if it already exists.")

    parser.add_argument("--vocab_size", default=30522, type=int,
                        help="The vocabulary size.")
    parser.add_argument("--max_position_embeddings", default=512, type=int,
                        help="Maximum sequence length we can model (including [CLS] and [SEP]).")
    parser.add_argument("--sinusoidal_pos_embds", action='store_false',
                        help="If true, the position embeddings are simply fixed with sinusoidal embeddings.")
    parser.add_argument("--n_layers", default=6, type=int,
                        help="Number of Transformer blocks.")
    parser.add_argument("--n_heads", default=12, type=int,
                        help="Number of heads in the self-attention module.")
    parser.add_argument("--dim", default=768, type=int,
                        help="Dimension through the network. Must be divisible by n_heads")
    parser.add_argument("--hidden_dim", default=3072, type=int,
                        help="Intermediate dimension in the FFN.")
    parser.add_argument("--dropout", default=0.1, type=float,
                        help="Dropout.")
    parser.add_argument("--attention_dropout", default=0.1, type=float,
                        help="Dropout in self-attention.")
    parser.add_argument("--activation", default='gelu', type=str,
                        help="Activation to use in self-attention")
    parser.add_argument("--tie_weights_", action='store_false',
                        help="If true, we tie the embeddings matrix with the projection over the vocabulary matrix. Default is true.")

    parser.add_argument("--from_pretrained_weights", default=None, type=str,
                        help="Load student initialization checkpoint.")
    parser.add_argument("--from_pretrained_config", default=None, type=str,
                        help="Load student initialization architecture config.")
    parser.add_argument("--teacher_type", default="bert", choices=["bert", "roberta"],
                        help="Teacher type (BERT, RoBERTa).")
    parser.add_argument("--teacher_name", default="bert-base-uncased", type=str,
                        help="The teacher model.")

    parser.add_argument("--temperature", default=2., type=float,
                        help="Temperature for the softmax temperature.")
    parser.add_argument("--alpha_ce", default=0.5, type=float,
                        help="Linear weight for the distillation loss. Must be >=0.")
    parser.add_argument("--alpha_mlm", default=0.5, type=float,
                        help="Linear weight for the MLM loss. Must be >=0.")
    parser.add_argument("--alpha_mse", default=0.0, type=float,
                        help="Linear weight of the MSE loss. Must be >=0.")
    parser.add_argument("--alpha_cos", default=0.0, type=float,
                        help="Linear weight of the cosine embedding loss. Must be >=0.")
    parser.add_argument("--mlm_mask_prop", default=0.15, type=float,
                        help="Proportion of tokens for which we need to make a prediction.")
    parser.add_argument("--word_mask", default=0.8, type=float,
                        help="Proportion of tokens to mask out.")
    parser.add_argument("--word_keep", default=0.1, type=float,
                        help="Proportion of tokens to keep.")
    parser.add_argument("--word_rand", default=0.1, type=float,
                        help="Proportion of tokens to randomly replace.")
    parser.add_argument("--mlm_smoothing", default=0.7, type=float,
                        help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).")
    parser.add_argument("--restrict_ce_to_mask", action='store_true',
                        help="If true, compute the distilation loss only the [MLM] prediction distribution.")

    parser.add_argument("--n_epoch", type=int, default=3,
                        help="Number of pass on the whole dataset.")
    parser.add_argument("--batch_size", type=int, default=5,
                        help="Batch size (for each process).")
    parser.add_argument("--tokens_per_batch", type=int, default=-1,
                        help="If specified, modify the batches so that they have approximately this number of tokens.")
    parser.add_argument("--shuffle", action='store_false',
                        help="If true, shuffle the sequence order. Default is true.")
    parser.add_argument("--group_by_size", action='store_false',
                        help="If true, group sequences that have similar length into the same batch. Default is true.")

    parser.add_argument("--gradient_accumulation_steps", type=int, default=50,
                        help="Gradient accumulation for larger training batches.")
    parser.add_argument("--warmup_prop", default=0.05, type=float,
                        help="Linear warmup proportion.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--learning_rate", default=5e-4, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--adam_epsilon", default=1e-6, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=5.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--initializer_range", default=0.02, type=float,
                        help="Random initialization range.")

    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--n_gpu", type=int, default=1,
                        help="Number of GPUs in the node.")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Distributed training - Local rank")
    parser.add_argument("--seed", type=int, default=56,
                        help="Random seed")

    parser.add_argument("--log_interval", type=int, default=500,
                        help="Tensorboard logging interval.")
    parser.add_argument("--checkpoint_interval", type=int, default=4000,
                        help="Checkpoint interval.")
    args = parser.parse_args()


    ## ARGS ##
    init_gpu_params(args)
    set_seed(args)
    if args.is_master:
        if os.path.exists(args.dump_path):
            if not args.force:
                raise ValueError(f'Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite it'
                                   'Use `--force` if you want to overwrite it')
            else:
                shutil.rmtree(args.dump_path)

        if not os.path.exists(args.dump_path):
            os.makedirs(args.dump_path)
        logger.info(f'Experiment will be dumped and logged in {args.dump_path}')


        ### SAVE PARAMS ###
        logger.info(f'Param: {args}')
        with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f:
            json.dump(vars(args), f, indent=4)
        git_log(args.dump_path)
    assert (args.from_pretrained_weights is None and args.from_pretrained_config is None) or \
           (args.from_pretrained_weights is not None and args.from_pretrained_config is not None)


    ### TOKENIZER ###
    if args.teacher_type == 'bert':
        tokenizer = BertTokenizer.from_pretrained(args.teacher_name)
    elif args.teacher_type == 'roberta':
        tokenizer = RobertaTokenizer.from_pretrained(args.teacher_name)
    special_tok_ids = {}
    for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
        idx = tokenizer.all_special_tokens.index(tok_symbol)
        special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
    logger.info(f'Special tokens {special_tok_ids}')
    args.special_tok_ids = special_tok_ids


    ## DATA LOADER ##
    logger.info(f'Loading data from {args.data_file}')
    with open(args.data_file, 'rb') as fp:
        data = pickle.load(fp)


    assert os.path.isfile(args.token_counts)
    logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)')
    with open(args.token_counts, 'rb') as fp:
        counts = pickle.load(fp)
        assert len(counts) == args.vocab_size
    token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
    for idx in special_tok_ids.values():
        token_probs[idx] = 0.  # do not predict special tokens
    token_probs = torch.from_numpy(token_probs)


    train_dataloader = Dataset(params=args, data=data)
    logger.info(f'Data loader created.')


    ## STUDENT ##
    if args.from_pretrained_weights is not None:
        assert os.path.isfile(args.from_pretrained_weights)
        assert os.path.isfile(args.from_pretrained_config)
        logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}')
        logger.info(f'Loading pretrained config from {args.from_pretrained_config}')
        stu_architecture_config = DistilBertConfig.from_json_file(args.from_pretrained_config)
        stu_architecture_config.output_hidden_states = True
        student = DistilBertForMaskedLM.from_pretrained(args.from_pretrained_weights,
                                                        config=stu_architecture_config)
    else:
        args.vocab_size_or_config_json_file = args.vocab_size
        stu_architecture_config = DistilBertConfig(**vars(args), output_hidden_states=True)
        student = DistilBertForMaskedLM(stu_architecture_config)


    if args.n_gpu > 0:
        student.to(f'cuda:{args.local_rank}')
    logger.info(f'Student loaded.')


    ## TEACHER ##
    if args.teacher_type == 'bert':
        teacher = BertForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
    elif args.teacher_type == 'roberta':
        teacher = RobertaForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
    if args.n_gpu > 0:
        teacher.to(f'cuda:{args.local_rank}')
    logger.info(f'Teacher loaded from {args.teacher_name}.')

    ## DISTILLER ##
    torch.cuda.empty_cache()
    distiller = Distiller(params=args,
                          dataloader=train_dataloader,
                          token_probs=token_probs,
                          student=student,
                          teacher=teacher)
    distiller.train()
    logger.info("Let's go get some drinks.")


if __name__ == "__main__":
    main()
